In [1]:
import sys
sys.path.append('../src/mane/prototype/')
import numpy as np
import graph as g
import pickle as p

from sklearn.preprocessing import normalize, scale, MultiLabelBinarizer
from sklearn.metrics import f1_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegressionCV

In [3]:
# Load weight
with open('../src/mane/prototype/embeddings/BC3047.weights', 'rb') as f:
    w = p.load(f)
# Load graph
bc = g.graph_from_pickle('../src/mane/data/blogcatalog3.graph', 
                         '../src/mane/data/blogcatalog3.community')

In [4]:
emb = (w[0] + w[1]) / 2
emb = normalize(emb)

In [5]:
(w[0] + w[1])[0]


Out[5]:
array([ -2.69344091e-01,   8.65176857e-01,   4.68841523e-01,
         4.20090824e-01,   5.83447695e-01,   6.75334990e-01,
         3.00385147e-01,  -9.26487267e-01,  -7.62748837e-01,
        -1.56417713e-02,   7.31598675e-01,  -2.51127988e-01,
        -3.47031593e-01,  -1.02715504e+00,   9.45819855e-01,
         1.40280694e-01,   1.03019625e-01,  -4.75595206e-01,
         2.48817280e-01,   9.25343513e-01,  -2.07803279e-01,
        -2.07494959e-01,   4.35739547e-01,  -1.04279496e-01,
        -3.08422536e-01,   7.80537009e-01,  -8.67948651e-01,
         1.16043016e-01,  -3.76683742e-01,  -4.76608366e-01,
         7.38691151e-01,  -8.75270903e-01,  -8.57649446e-01,
         6.21079683e-01,   2.47170091e-01,  -7.57386565e-01,
         7.37396181e-01,   5.21465857e-03,   5.16199112e-01,
         7.28622079e-04,  -6.79813385e-01,  -9.31952149e-02,
         1.06434941e+00,   8.75964820e-01,  -6.31770194e-01,
        -2.75992483e-01,   9.85013172e-02,  -4.79730874e-01,
         8.04522276e-01,  -4.20392275e-01,   9.56339836e-01,
        -6.56542480e-01,   6.78624749e-01,  -8.87686133e-01,
         2.86612689e-01,   4.63127762e-01,   2.50756174e-01,
        -9.03050780e-01,  -2.89407164e-01,  -2.27195710e-01,
        -7.09807798e-02,   1.16347186e-01,  -7.90337563e-01,
         1.90274835e-01,   8.13606739e-01,   8.70082617e-01,
        -7.19924569e-01,  -3.46752524e-01,   9.34063673e-01,
        -4.35349911e-01,   2.15740979e-01,  -1.56441346e-01,
        -6.64027989e-01,   1.07105434e+00,   4.54691857e-01,
         7.86813021e-01,   2.07491629e-02,  -5.59872270e-01,
         6.17382050e-01,  -7.14152455e-01,   3.70913506e-01,
        -3.29529047e-01,   1.33344904e-01,  -9.51524734e-01,
        -4.52669591e-01,  -1.47972375e-01,   7.83605099e-01,
        -1.07437992e+00,   9.49787021e-01,   4.42183703e-01,
         2.73181438e-01,   5.06286249e-02,  -9.45496261e-01,
         1.92179978e-01,   7.90552139e-01,   8.87426674e-01,
         6.54974580e-02,   1.06308472e+00,   8.20663810e-01,
        -4.18068916e-01,  -3.74885380e-01,  -3.89595896e-01,
        -5.61836958e-02,   3.22974056e-01,   9.99552608e-01,
        -2.09551007e-01,   8.43216062e-01,  -6.11203253e-01,
         3.80809546e-01,   4.59313959e-01,   6.26192451e-01,
        -8.46500158e-01,   7.35134184e-01,  -6.93990707e-01,
         1.00871849e+00,  -5.95232427e-01,   4.20968771e-01,
         6.81039393e-02,   2.93322444e-01,   7.50425339e-01,
        -2.51413614e-01,  -8.67973983e-01,  -1.08505833e+00,
         4.74990159e-01,   3.72020781e-01,   6.85470581e-01,
        -1.01328218e+00,  -3.08616340e-01], dtype=float32)

In [6]:
normalize(_)


/home/hoangnt/anaconda3/lib/python3.5/site-packages/sklearn/utils/validation.py:386: DeprecationWarning: Passing 1d arrays as data is deprecated in 0.17 and willraise ValueError in 0.19. Reshape your data either using X.reshape(-1, 1) if your data has a single feature or X.reshape(1, -1) if it contains a single sample.
  DeprecationWarning)
Out[6]:
array([[ -3.86642106e-02,   1.24195710e-01,   6.73019662e-02,
          6.03038296e-02,   8.37536231e-02,   9.69440043e-02,
          4.31201383e-02,  -1.32996783e-01,  -1.09492213e-01,
         -2.24536844e-03,   1.05020627e-01,  -3.60492989e-02,
         -4.98162135e-02,  -1.47447601e-01,   1.35771975e-01,
          2.01372243e-02,   1.47884162e-02,  -6.82714581e-02,
          3.57175954e-02,   1.32832602e-01,  -2.98300572e-02,
         -2.97857989e-02,   6.25501946e-02,  -1.49692697e-02,
         -4.42739017e-02,   1.12045698e-01,  -1.24593593e-01,
          1.66579168e-02,  -5.40727638e-02,  -6.84168935e-02,
          1.06038749e-01,  -1.25644699e-01,  -1.23115152e-01,
          8.91556814e-02,   3.54811437e-02,  -1.08722463e-01,
          1.05852857e-01,   7.48561637e-04,   7.41001219e-02,
          1.04593339e-04,  -9.75868702e-02,  -1.33781265e-02,
          1.52786821e-01,   1.25744313e-01,  -9.06902999e-02,
         -3.96185853e-02,   1.41398152e-02,  -6.88651279e-02,
          1.15488775e-01,  -6.03471026e-02,   1.37282103e-01,
         -9.42463428e-02,   9.74162444e-02,  -1.27426907e-01,
          4.11431082e-02,   6.64817616e-02,   3.59959230e-02,
         -1.29632488e-01,  -4.15442549e-02,  -3.26138325e-02,
         -1.01892557e-02,   1.67015810e-02,  -1.13452561e-01,
          2.73138583e-02,   1.16792843e-01,   1.24899924e-01,
         -1.03344813e-01,  -4.97761518e-02,   1.34084374e-01,
         -6.24942631e-02,   3.09695099e-02,  -2.24570781e-02,
         -9.53208879e-02,   1.53749317e-01,   6.52707890e-02,
          1.12946615e-01,   2.97853211e-03,  -8.03693831e-02,
          8.86248872e-02,  -1.02516226e-01,   5.32444492e-02,
         -4.73037325e-02,   1.91415939e-02,  -1.36590898e-01,
         -6.49804920e-02,  -2.12413613e-02,   1.12486124e-01,
         -1.54226705e-01,   1.36341453e-01,   6.34752512e-02,
          3.92150581e-02,   7.26771401e-03,  -1.35725513e-01,
          2.75873393e-02,   1.13483369e-01,   1.27389655e-01,
          9.40212794e-03,   1.52605280e-01,   1.17805883e-01,
         -6.00135848e-02,  -5.38146086e-02,  -5.59262969e-02,
         -8.06514174e-03,   4.63627651e-02,   1.43485278e-01,
         -3.00809424e-02,   1.21043243e-01,  -8.77379254e-02,
          5.46650216e-02,   6.59342930e-02,   8.98896158e-02,
         -1.21514678e-01,   1.05528146e-01,  -9.96220186e-02,
          1.44801036e-01,  -8.54453221e-02,   6.04298562e-02,
          9.77628678e-03,   4.21062894e-02,   1.07723184e-01,
         -3.60902995e-02,  -1.24597237e-01,  -1.55759588e-01,
          6.81845993e-02,   5.34033962e-02,   9.83989611e-02,
         -1.45456150e-01,  -4.43017222e-02]], dtype=float32)

In [7]:
emb[0]


Out[7]:
array([ -3.86642106e-02,   1.24195710e-01,   6.73019662e-02,
         6.03038296e-02,   8.37536231e-02,   9.69440043e-02,
         4.31201383e-02,  -1.32996783e-01,  -1.09492213e-01,
        -2.24536844e-03,   1.05020627e-01,  -3.60492989e-02,
        -4.98162135e-02,  -1.47447601e-01,   1.35771975e-01,
         2.01372243e-02,   1.47884162e-02,  -6.82714581e-02,
         3.57175954e-02,   1.32832602e-01,  -2.98300572e-02,
        -2.97857989e-02,   6.25501946e-02,  -1.49692697e-02,
        -4.42739017e-02,   1.12045698e-01,  -1.24593593e-01,
         1.66579168e-02,  -5.40727638e-02,  -6.84168935e-02,
         1.06038749e-01,  -1.25644699e-01,  -1.23115152e-01,
         8.91556814e-02,   3.54811437e-02,  -1.08722463e-01,
         1.05852857e-01,   7.48561637e-04,   7.41001219e-02,
         1.04593339e-04,  -9.75868702e-02,  -1.33781265e-02,
         1.52786821e-01,   1.25744313e-01,  -9.06902999e-02,
        -3.96185853e-02,   1.41398152e-02,  -6.88651279e-02,
         1.15488775e-01,  -6.03471026e-02,   1.37282103e-01,
        -9.42463428e-02,   9.74162444e-02,  -1.27426907e-01,
         4.11431082e-02,   6.64817616e-02,   3.59959230e-02,
        -1.29632488e-01,  -4.15442549e-02,  -3.26138325e-02,
        -1.01892557e-02,   1.67015810e-02,  -1.13452561e-01,
         2.73138583e-02,   1.16792843e-01,   1.24899924e-01,
        -1.03344813e-01,  -4.97761518e-02,   1.34084374e-01,
        -6.24942631e-02,   3.09695099e-02,  -2.24570781e-02,
        -9.53208879e-02,   1.53749317e-01,   6.52707890e-02,
         1.12946615e-01,   2.97853211e-03,  -8.03693831e-02,
         8.86248872e-02,  -1.02516226e-01,   5.32444492e-02,
        -4.73037325e-02,   1.91415939e-02,  -1.36590898e-01,
        -6.49804920e-02,  -2.12413613e-02,   1.12486124e-01,
        -1.54226705e-01,   1.36341453e-01,   6.34752512e-02,
         3.92150581e-02,   7.26771401e-03,  -1.35725513e-01,
         2.75873393e-02,   1.13483369e-01,   1.27389655e-01,
         9.40212794e-03,   1.52605280e-01,   1.17805883e-01,
        -6.00135848e-02,  -5.38146086e-02,  -5.59262969e-02,
        -8.06514174e-03,   4.63627651e-02,   1.43485278e-01,
        -3.00809424e-02,   1.21043243e-01,  -8.77379254e-02,
         5.46650216e-02,   6.59342930e-02,   8.98896158e-02,
        -1.21514678e-01,   1.05528146e-01,  -9.96220186e-02,
         1.44801036e-01,  -8.54453221e-02,   6.04298562e-02,
         9.77628678e-03,   4.21062894e-02,   1.07723184e-01,
        -3.60902995e-02,  -1.24597237e-01,  -1.55759588e-01,
         6.81845993e-02,   5.34033962e-02,   9.83989611e-02,
        -1.45456150e-01,  -4.43017222e-02], dtype=float32)

In [9]:
x_train, yl_train, x_test, yl_test = bc.get_ids_labels(0.5)

In [10]:
X_train = [emb[i] for i in x_train]
Y_train = MultiLabelBinarizer().fit_transform(yl_train)

In [11]:
Y_train.shape


Out[11]:
(5156, 39)

In [12]:
for i,j in bc._communities.items():
    if 39 in j:
        print(i)


14
691
1250
1344
1465
1550
4709
7759

There is only 8 nodes with community 39. This might cause a problem.


In [15]:
bc._communities[1465]


Out[15]:
[39]

In [65]:
lg = OneVsRestClassifier(LogisticRegression(C=1e5))

In [66]:
lg.fit(X_train, Y_train)


Out[66]:
OneVsRestClassifier(estimator=LogisticRegression(C=100000.0, class_weight=None, dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
          solver='liblinear', tol=0.0001, verbose=0, warm_start=False),
          n_jobs=1)

In [33]:
lg.predict(emb[9566].reshape(1,-1))


Out[33]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [30]:
emb[5].dot(emb[0])


Out[30]:
0.097585022

In [31]:
x_train[0]


Out[31]:
9566

In [32]:
x_train[1]


Out[32]:
2378

In [38]:
Y_train[8]


Out[38]:
array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [39]:
lg.predict_proba(emb[1234].reshape(1,-1))


Out[39]:
array([[ 0.02476453,  0.01237352,  0.07884584,  0.0149981 ,  0.08755578,
         0.12266222,  0.01220914,  0.15262878,  0.01222904,  0.01525664,
         0.02794357,  0.00305634,  0.06974116,  0.01687944,  0.01621081,
         0.04848325,  0.06587311,  0.01920561,  0.0962717 ,  0.02722684,
         0.01389558,  0.04015614,  0.05028069,  0.09708167,  0.01009068,
         0.05425601,  0.006924  ,  0.00494262,  0.01795537,  0.025277  ,
         0.00711691,  0.02192811,  0.02125015,  0.0056504 ,  0.01110642,
         0.01373995,  0.01219028,  0.00320059,  0.00219195]])

In [40]:
bc._communities[1234]


Out[40]:
[8]

In [56]:
lg.predict_proba(emb[1234].reshape(1,-1)).argsort()[0][-4:]


Out[56]:
array([18, 23,  5,  7])

In [60]:
lg.predict_proba(emb[5437].reshape(1,-1)).argsort()[0]


Out[60]:
array([38, 20, 11, 34, 33, 37, 36, 26, 30, 27,  3, 10, 14,  8, 13, 35,  1,
       19, 32, 28,  0, 24, 31,  9, 17, 25, 15, 21, 22, 29,  6,  2,  4, 12,
       18, 16,  5,  7, 23])

In [58]:
bc._communities[5437]


Out[58]:
[33]

In [61]:
for i in bc[5437]:
    print(bc._communities[i])


[5, 6, 7, 24]
[9]
[8, 13, 17]
[4]
[3, 31, 32]
[13, 17, 33]

In [62]:
bc[5437]


Out[62]:
[991, 4839, 5354, 6753, 6832, 7999]

In [72]:
for i in bc[7999]:
    if 32 in bc._communities[i]:
        print(i)


5379
6066
6984

In [77]:
lg.predict_proba(emb[6984].reshape(1,-1))[0].argmax()


Out[77]:
4

In [78]:
for x in [14,691,1250,1344,1465,1550,4709,7759]:
    if x in x_train:
        print('la')


la
la
la
la

In [ ]: